package com.chatgpt.api.conf;

import com.chatgpt.api.websocket.handler.ChatHandler;
import com.chatgpt.api.websocket.interceptor.WebsocketInterceptor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** URL path at which the chat WebSocket handler is registered. */
    private static final String CHAT_ENDPOINT = "/api/chat/new";

    /**
     * Origin value passed to the handshake origin check; {@code "*"} accepts any origin,
     * which is how this config enables cross-origin clients.
     * NOTE(review): if the handshake interceptor authenticates via cookies, accepting every
     * origin allows cross-site WebSocket hijacking — consider restricting this to the known
     * front-end origins (or using setAllowedOriginPatterns) once they are known.
     */
    private static final String ALLOWED_ORIGINS = "*";

    private final ChatHandler chatHandler;
    private final WebsocketInterceptor websocketInterceptor;

    /**
     * Creates the configuration with collaborators supplied through constructor injection.
     *
     * @param chatHandler          WebSocket handler bound to the chat endpoint
     * @param websocketInterceptor handshake interceptor that pre-processes connection requests
     */
    public WebSocketConfig(ChatHandler chatHandler, WebsocketInterceptor websocketInterceptor) {
        this.websocketInterceptor = websocketInterceptor;
        this.chatHandler = chatHandler;
    }

    /**
     * Registers the chat handler at {@link #CHAT_ENDPOINT}, attaches the handshake
     * interceptor, and applies the origin policy for the handshake.
     *
     * @param registry Spring's registry of WebSocket handlers
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(chatHandler, CHAT_ENDPOINT)
                // Run the interceptor during the handshake, before the connection is accepted.
                .addInterceptors(websocketInterceptor)
                // Accept handshakes from other origins (cross-origin support).
                .setAllowedOrigins(ALLOWED_ORIGINS);
    }
}
